package com.yitong.framework.utils;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpServletResponseWrapper;

/**
 * Servlet filter that minifies HTML and JSON responses by removing line breaks and the
 * whitespace between adjacent tags. Blocks matched by {@link #codePattern} (delimited by
 * {@code p-re} markers in either literal {@code <...>} or escaped {@code &lt;...&gt;} form)
 * are preserved byte-for-byte.
 *
 * <p>Requests whose URI contains {@code "update"} or {@code "create"} bypass the filter.
 * All other responses are buffered in memory so the body can be rewritten before it is
 * sent to the client.
 *
 * <p>NOTE: the same whitespace rules are applied to JSON bodies, so a {@code >} followed by
 * whitespace and a {@code <} inside a JSON string value is collapsed as well.
 *
 * @author www.zuidaima.com
 */
public class HtmlFilter implements Filter {

	/** Regex source matching one protected block, including its opening and closing markers. */
	public static final String codePatternStr = "(<|&lt;)p-re[\\w\\W]*?[\\w\\W]*?/p-re[\\w\\W]*?(>|&gt;)";
	/**
	 * Case-insensitive compiled form of {@link #codePatternStr}. Group 1 is the leading
	 * {@code <} / {@code &lt;}; group 2 is the trailing {@code >} / {@code &gt;}.
	 */
	public static final Pattern codePattern = Pattern.compile(codePatternStr,
			Pattern.CASE_INSENSITIVE);

	/** Content types (matched against the whole header value) whose bodies are minified. */
	private static final Pattern MINIFIABLE_CONTENT_TYPE = Pattern.compile(".*?(html|json).*?");

	/** Carriage returns and line feeds, removed everywhere outside protected blocks. */
	private static final Pattern LINE_BREAK = Pattern.compile("[\r\n]");

	/** Whitespace between the end of one tag and the start of the next. */
	private static final Pattern TAG_GAP = Pattern.compile(">\\s*?<");

	/** Encoding assumed when the response does not report one. */
	private static final String DEFAULT_ENCODING = "ISO-8859-1";

	private FilterConfig filterConfig = null;

	/** A protected block removed before minification, together with its bracket style. */
	class CodeFragment {

		private String left;
		private String fragment;
		private String right;

		public CodeFragment(String left, String fragment, String right) {
			this.left = left;
			this.fragment = fragment;
			this.right = right;
		}

		public String getLeft() {
			return left;
		}

		public void setLeft(String left) {
			this.left = left;
		}

		public String getFragment() {
			return fragment;
		}

		public void setFragment(String fragment) {
			this.fragment = fragment;
		}

		public String getRight() {
			return right;
		}

		public void setRight(String right) {
			this.right = right;
		}
	}

	/** {@link ServletOutputStream} that appends everything written to an in-memory buffer. */
	private static class ByteArrayServletStream extends ServletOutputStream {
		private final ByteArrayOutputStream baos;

		ByteArrayServletStream(ByteArrayOutputStream baos) {
			this.baos = baos;
		}

		@Override
		public void write(int param) throws IOException {
			baos.write(param);
		}

		@Override
		public void write(byte[] b, int off, int len) throws IOException {
			// Bulk copy instead of OutputStream's default byte-at-a-time loop.
			baos.write(b, off, len);
		}
	}

	/** In-memory sink shared by the writer and the stream handed out by the wrapper. */
	private static class ByteArrayPrintWriter {

		private final ByteArrayOutputStream baos = new ByteArrayOutputStream();

		private final ServletOutputStream sos = new ByteArrayServletStream(baos);

		/** Created lazily so it encodes with the response's character encoding. */
		private PrintWriter pw;

		/** Replaces the writer with one that encodes characters into the buffer using {@code charset}. */
		void openWriter(Charset charset) {
			flush();
			pw = new PrintWriter(new OutputStreamWriter(baos, charset));
		}

		/** Returns the current writer, or {@code null} if {@link #openWriter} was never called. */
		PrintWriter getWriter() {
			return pw;
		}

		ServletOutputStream getStream() {
			return sos;
		}

		/** Pushes characters still buffered inside the writer into the byte buffer. */
		void flush() {
			if (pw != null) {
				pw.flush();
			}
		}

		/** Discards everything written so far, including characters pending in the writer. */
		void clear() {
			flush();
			baos.reset();
		}

		/** Returns a copy of everything written so far, flushing the writer first. */
		byte[] toByteArray() {
			flush();
			return baos.toByteArray();
		}
	}

	/**
	 * Response wrapper that captures the body in memory instead of sending it.
	 * Headers and status still go straight to the wrapped response, except
	 * {@code Content-Length}, which is only recorded because the filter may change the body.
	 */
	public class CharResponseWrapper extends HttpServletResponseWrapper {
		private final ByteArrayPrintWriter output;
		private boolean usingWriter;
		private boolean usingStream;
		private boolean contentLengthSet;

		public CharResponseWrapper(HttpServletResponse response) {
			super(response);
			output = new ByteArrayPrintWriter();
		}

		/** Returns a copy of the captured body bytes. */
		public byte[] getByteArray() {
			return output.toByteArray();
		}

		/** Returns whether downstream code called {@link #setContentLength(int)}. */
		boolean isContentLengthSet() {
			return contentLengthSet;
		}

		/**
		 * Returns the capturing stream; may be called repeatedly.
		 *
		 * @throws IllegalStateException if {@link #getWriter()} was already used
		 */
		@Override
		public ServletOutputStream getOutputStream() throws IOException {
			if (usingWriter) {
				throw new IllegalStateException(
						"getWriter() has already been called for this response");
			}
			usingStream = true;
			return output.getStream();
		}

		/**
		 * Returns the capturing writer, encoding with the response's character encoding
		 * as it was on the first call; may be called repeatedly.
		 *
		 * @throws IllegalStateException if {@link #getOutputStream()} was already used
		 * @throws UnsupportedEncodingException if the response encoding cannot be used
		 */
		@Override
		public PrintWriter getWriter() throws IOException {
			if (usingStream) {
				throw new IllegalStateException(
						"getOutputStream() has already been called for this response");
			}
			if (!usingWriter) {
				String encoding = getCharacterEncoding();
				Charset charset = resolveCharset(encoding);
				if (charset == null) {
					throw new UnsupportedEncodingException(encoding);
				}
				output.openWriter(charset);
				usingWriter = true;
			}
			return output.getWriter();
		}

		/** Records that a length was declared; the real one is set by the filter. */
		@Override
		public void setContentLength(int len) {
			contentLengthSet = true;
		}

		/** Flushes only into the buffer, so the real response is not committed early. */
		@Override
		public void flushBuffer() throws IOException {
			output.flush();
		}

		@Override
		public void resetBuffer() {
			super.resetBuffer();
			output.clear();
		}

		/** Resets the wrapped response and also discards the captured body and stream/writer state. */
		@Override
		public void reset() {
			super.reset();
			output.clear();
			usingWriter = false;
			usingStream = false;
			contentLengthSet = false;
		}

		/** Returns the captured body decoded with the response's character encoding. */
		@Override
		public String toString() {
			Charset charset = resolveCharset(getCharacterEncoding());
			if (charset == null) {
				charset = Charset.forName(DEFAULT_ENCODING);
			}
			return new String(getByteArray(), charset);
		}
	}

	/**
	 * Runs the chain against a buffering wrapper and then writes the (possibly minified)
	 * body to the real response. Non-HTTP and excluded requests pass straight through.
	 */
	@Override
	public void doFilter(ServletRequest request, ServletResponse response,
			FilterChain chain) throws IOException, ServletException {
		if (!(request instanceof HttpServletRequest)
				|| !(response instanceof HttpServletResponse)
				|| isExcluded((HttpServletRequest) request)) {
			chain.doFilter(request, response);
			return;
		}
		CharResponseWrapper wrappedResponse = new CharResponseWrapper(
				(HttpServletResponse) response);
		chain.doFilter(request, wrappedResponse);
		byte[] bytes = wrappedResponse.getByteArray();

		String contentType = wrappedResponse.getContentType();
		if (contentType != null
				&& MINIFIABLE_CONTENT_TYPE.matcher(contentType).matches()) {
			// Decode and re-encode with the response's own charset rather than the platform
			// default so non-ASCII text survives; if it cannot be resolved, send bytes as-is.
			Charset charset = resolveCharset(wrappedResponse.getCharacterEncoding());
			if (charset != null) {
				bytes = minify(new String(bytes, charset)).getBytes(charset);
			}
		}
		if (wrappedResponse.isContentLengthSet()) {
			// Downstream declared the length of the original body; declare the actual one.
			response.setContentLength(bytes.length);
		}
		response.getOutputStream().write(bytes);
	}

	/** Returns true for requests whose URI contains "update" or "create". */
	private static boolean isExcluded(HttpServletRequest request) {
		String uri = request.getRequestURI();
		return uri.indexOf("update") != -1 || uri.indexOf("create") != -1;
	}

	/**
	 * Removes line breaks and inter-tag whitespace from {@code text} and trims it,
	 * leaving every {@link #codePattern} match intact.
	 */
	private String minify(String text) {
		Matcher matcher = codePattern.matcher(text);
		List<CodeFragment> codeFragments = new ArrayList<CodeFragment>();
		String out = text;
		while (matcher.find()) {
			String fragment = matcher.group(0);
			String left = matcher.group(1);
			String right = matcher.group(2);
			codeFragments.add(new CodeFragment(left, fragment, right));
			// Swap the block for an indexed placeholder such as <p-re>1</p-re> so the
			// whitespace rules below cannot touch it. The matcher keeps scanning the
			// original text, not this rewritten copy.
			out = out.replace(fragment, placeholder(left, right, codeFragments.size()));
		}
		out = LINE_BREAK.matcher(out).replaceAll("");
		out = TAG_GAP.matcher(out).replaceAll("><").trim();
		// Restore the protected blocks; placeholders are numbered from 1.
		for (int i = 0; i < codeFragments.size(); i++) {
			CodeFragment codeFragment = codeFragments.get(i);
			out = out.replace(
					placeholder(codeFragment.getLeft(), codeFragment.getRight(), i + 1),
					codeFragment.getFragment());
		}
		return out;
	}

	/** Builds the stand-in text for the {@code index}-th protected block. */
	private static String placeholder(String left, String right, int index) {
		return left + "p-re" + right + index + left + "/p-re" + right;
	}

	/**
	 * Resolves a character-encoding name, treating {@code null} as ISO-8859-1.
	 *
	 * @return the charset, or {@code null} if the name is illegal or unsupported
	 */
	private static Charset resolveCharset(String encoding) {
		try {
			return Charset.forName(encoding == null ? DEFAULT_ENCODING : encoding);
		} catch (IllegalArgumentException e) {
			// IllegalCharsetNameException and UnsupportedCharsetException both extend this.
			return null;
		}
	}

	@Override
	public void init(FilterConfig filterConfig) throws ServletException {
		this.filterConfig = filterConfig;
	}

	@Override
	public void destroy() {
		filterConfig = null;
	}
}